require "stringio"
require_relative "polyfill/unpack1"

module Prism
  # A module responsible for deserializing parse results.
  module Serialize
    # The major version of prism that we are expecting to find in the serialized
    # strings. Loader#load_header compares it against the first version byte
    # that follows the "PRISM" magic and raises when they differ.
    MAJOR_VERSION = 0

    # The minor version of prism that we are expecting to find in the serialized
    # strings. Loader#load_header compares it against the second version byte.
    MINOR_VERSION = 30

    # The patch version of prism that we are expecting to find in the serialized
    # strings. Loader#load_header compares it against the third version byte.
    PATCH_VERSION = 0

    # Turn the serialized AST in +serialized+ into a parse result for +input+.
    #
    # The caller's +input+ string is never touched: a duplicate is handed to
    # Source.for, and that duplicate is re-tagged with the encoding the loader
    # discovered once deserialization has finished.
    def self.load(input, serialized)
      copy = input.dup
      loader = Loader.new(Source.for(copy), serialized)

      loader.load_result.tap { copy.force_encoding(loader.encoding) }
    end

    # Turn the serialized token stream in +serialized+ into a lex result,
    # using +source+ as the Source the tokens point into.
    def self.load_tokens(source, serialized)
      loader = Loader.new(source, serialized)
      loader.load_tokens_result
    end

    class Loader # :nodoc:
      if RUBY_ENGINE == "truffleruby"
        # StringIO is synchronized and that adds a high overhead on TruffleRuby.
        # This stand-in implements only the byte-oriented calls the loader
        # makes: #getbyte, #read, #eof? and a readable/writable #pos.
        class FastStringIO # :nodoc:
          # Current byte offset into the wrapped string.
          attr_accessor :pos

          def initialize(string)
            @string = string
            @pos = 0
          end

          # Return the byte at the cursor (nil past the end) and advance by one.
          def getbyte
            offset = @pos
            @pos = offset + 1
            @string.getbyte(offset)
          end

          # Return up to +n+ bytes starting at the cursor and advance by +n+.
          def read(n)
            chunk = @string.byteslice(@pos, n)
            @pos += n
            chunk
          end

          # True once the cursor has reached or passed the end of the string.
          def eof?
            @string.bytesize <= @pos
          end
        end
      else
        FastStringIO = ::StringIO
      end
      private_constant :FastStringIO

      # Loader state: the current encoding (UTF-8 until load_encoding replaces
      # it), the loader's own copy of the source text, the serialized bytes, and
      # the IO-like cursor reading those bytes.
      attr_reader :encoding, :input, :serialized, :io
      # Constant pool offset and cache (both nil until load_nodes assigns them)
      # and the Source object that was passed to the constructor.
      attr_reader :constant_pool_offset, :constant_pool, :source
      # NOTE(review): nothing in this file assigns @start_line on the loader
      # (load_start_line stores the value on +source+ instead), so this reader
      # appears to always return nil. Confirm before relying on it.
      attr_reader :start_line

      # Build a loader that reads +serialized+ on behalf of +source+.
      #
      # source::     object whose #source returns the source text; the loader
      #              keeps its own duplicate of that text in @input.
      # serialized:: the serialized bytes; must be tagged Encoding::BINARY.
      #
      # Raises RuntimeError when +serialized+ is not binary-encoded.
      def initialize(source, serialized)
        # Default until load_encoding reads the real encoding from the stream.
        @encoding = Encoding::UTF_8

        @input = source.source.dup
        unless serialized.encoding == Encoding::BINARY
          # A bare `raise` would surface as the opaque "unhandled exception".
          raise "Expected the serialized string to be binary (ASCII-8BIT) encoded, got #{serialized.encoding}"
        end
        @serialized = serialized
        @io = FastStringIO.new(serialized)

        # Both are assigned by load_nodes once the metadata has been read.
        @constant_pool_offset = nil
        @constant_pool = nil

        @source = source
        # Engines other than CRuby dispatch node loading through a lambda table.
        define_load_node_lambdas unless RUBY_ENGINE == "ruby"
      end

      # Validate the fixed header at the start of the serialized data.
      #
      # The header is the 5-byte magic "PRISM", three version bytes that must
      # equal MAJOR_VERSION, MINOR_VERSION and PATCH_VERSION, and one flag byte
      # that must be 0 (meaning location fields are present).
      #
      # Raises RuntimeError when any of those checks fails, including when the
      # data ends right after the magic.
      def load_header
        raise "Invalid serialization" if io.read(5) != "PRISM"

        # read returns nil at end of input; treat that as a version mismatch
        # instead of letting nil.unpack raise NoMethodError.
        version = io.read(3)&.unpack("C3")
        raise "Invalid serialization" if version != [MAJOR_VERSION, MINOR_VERSION, PATCH_VERSION]

        only_semantic_fields = io.getbyte
        unless only_semantic_fields == 0
          raise "Invalid serialization (location fields must be included but are not)"
        end
      end

      # Read the length-prefixed encoding name, look it up, and remember it.
      # The loader's copy of the source text is re-tagged with that encoding
      # and frozen. Returns the Encoding.
      def load_encoding
        encoding_name = io.read(load_varuint)
        @encoding = Encoding.find(encoding_name)
        @input = input.force_encoding(@encoding).freeze
        @encoding
      end

      # Read the signed start line and store it on the source as @start_line.
      def load_start_line
        source.instance_variable_set(:@start_line, load_varsint)
      end

      # Read a count followed by that many offsets and store the resulting
      # array on the source as @offsets.
      def load_line_offsets
        target = source
        offsets = Array.new(load_varuint) { load_varuint }
        target.instance_variable_set(:@offsets, offsets)
      end

      # Read a count followed by that many comments.
      #
      # Each comment is a type tag (0 for InlineComment, 1 for EmbDocComment)
      # followed by a location. Returns the array of comment objects.
      #
      # Raises RuntimeError on any other tag, so corrupt input cannot silently
      # put nil into the comment list.
      def load_comments
        Array.new(load_varuint) do
          type = load_varuint

          case type
          when 0 then InlineComment.new(load_location_object)
          when 1 then EmbDocComment.new(load_location_object)
          else raise "Unknown comment type: #{type}"
          end
        end
      end

      # Diagnostic type symbols, generated by the template from the error
      # names followed by the warning names (each downcased). load_metadata
      # looks entries up by the integer read from the serialized data.
      DIAGNOSTIC_TYPES = [
        <%- errors.each do |error| -%>
        <%= error.name.downcase.to_sym.inspect %>,
        <%- end -%>
        <%- warnings.each do |warning| -%>
        <%= warning.name.downcase.to_sym.inspect %>,
        <%- end -%>
      ].freeze

      private_constant :DIAGNOSTIC_TYPES

      # Read the metadata section, in stream order: comments, magic comments,
      # the optional data location, errors, and warnings.
      #
      # Returns those five values as an array in that same order.
      def load_metadata
        comments = load_comments

        magic_comments = Array.new(load_varuint) do
          key_loc = load_location_object
          value_loc = load_location_object
          MagicComment.new(key_loc, value_loc)
        end

        data_loc = load_optional_location_object

        errors = Array.new(load_varuint) do
          type = DIAGNOSTIC_TYPES.fetch(load_varuint)
          message = load_embedded_string
          location = load_location_object
          ParseError.new(type, message, location, load_error_level)
        end

        warnings = Array.new(load_varuint) do
          type = DIAGNOSTIC_TYPES.fetch(load_varuint)
          message = load_embedded_string
          location = load_location_object
          ParseWarning.new(type, message, location, load_warning_level)
        end

        [comments, magic_comments, data_loc, errors, warnings]
      end

      # Read token records until a type index that maps to nil in TOKEN_TYPES
      # is seen (index 0 is the nil entry).
      #
      # Each record is a type index, a start, a length and a lex state. Returns
      # an array of [Token, lex_state] pairs.
      def load_tokens
        pairs = []

        while (type = TOKEN_TYPES.fetch(load_varuint))
          start, length, lex_state = load_varuint, load_varuint, load_varuint
          location = Location.new(@source, start, length)
          token = Token.new(source, type, location.slice, location)
          pairs.push([token, lex_state])
        end

        pairs
      end

      # Read a complete lex result: tokens, encoding, start line, line offsets
      # and metadata, in that order. Every token value is re-tagged with the
      # encoding that was read.
      #
      # Raises RuntimeError if any serialized bytes remain afterwards.
      def load_tokens_result
        tokens = load_tokens
        encoding = load_encoding
        load_start_line
        load_line_offsets
        metadata = load_metadata

        tokens.each do |(token, _lex_state)|
          token.value.force_encoding(encoding)
        end

        unless @io.eof?
          raise "Expected to consume all bytes while deserializing"
        end

        LexResult.new(tokens, *metadata, @source)
      end

      # Read a complete serialized AST: header, encoding, start line, line
      # offsets, metadata, the constant pool offset and size, and finally the
      # root node.
      #
      # Returns [node, comments, magic_comments, data_loc, errors, warnings].
      def load_nodes
        load_header
        load_encoding
        load_start_line
        load_line_offsets

        metadata = load_metadata

        # Constants are decoded on demand by load_constant; start with an
        # all-nil cache of the advertised size.
        @constant_pool_offset = load_uint32
        pool_size = load_varuint
        @constant_pool = Array.new(pool_size, nil)

        root = load_node
        [root, *metadata]
      end

      # Read a complete serialized AST and wrap it, together with the source,
      # in a ParseResult.
      def load_result
        ParseResult.new(*load_nodes, @source)
      end

      private

      # Read one unsigned variable-length integer (LEB128, see
      # https://en.wikipedia.org/wiki/LEB128; protobuf uses the same scheme:
      # https://protobuf.dev/programming-guides/encoding/#varints).
      #
      # Every byte contributes its low 7 bits, least significant group first;
      # a set high bit means another byte follows.
      def load_varuint
        byte = io.getbyte
        return byte if byte < 128

        value = byte & 0x7f
        shift = 7

        while (byte = io.getbyte) >= 128
          value |= (byte & 0x7f) << shift
          shift += 7
        end

        value | (byte << shift)
      end

      # Read one signed variable-length integer. The value is stored as an
      # unsigned varint whose lowest bit carries the sign (zigzag encoding):
      # 0, 1, 2, 3, ... decode to 0, -1, 1, -2, ...
      def load_varsint
        encoded = load_varuint
        magnitude = encoded >> 1
        encoded.odd? ? ~magnitude : magnitude
      end

      # Read an arbitrary-precision integer: one sign byte (non-zero means
      # negative), a word count, then that many varint words. Word +i+ is
      # shifted left by 32 * i bits and OR-ed into the result.
      def load_integer
        negative = io.getbyte != 0
        word_count = load_varuint

        magnitude = (0...word_count).reduce(0) do |acc, index|
          acc | (load_varuint << (index * 32))
        end

        negative ? -magnitude : magnitude
      end

      # Read 8 bytes and decode them as a native-endian double ("D").
      def load_double
        bytes = io.read(8)
        bytes.unpack1("D")
      end

      # Read 4 bytes and decode them as a native-endian unsigned 32-bit
      # integer ("L").
      def load_uint32
        bytes = io.read(4)
        bytes.unpack1("L")
      end

      # Read a node that may be absent. A zero byte means "no node" and yields
      # nil; any other byte is the start of the node itself, so the cursor is
      # stepped back by one before load_node reads it.
      def load_optional_node
        return if io.getbyte == 0

        io.pos -= 1
        load_node
      end

      # Read a length-prefixed string stored inline in the serialized data and
      # tag it with the loader's current encoding.
      def load_embedded_string
        bytes = io.read(load_varuint)
        bytes.force_encoding(encoding)
      end

      # Read a string field. The leading tag byte selects the representation:
      # 1 is a (start, length) slice of the source text, 2 is a string embedded
      # in the serialized data.
      #
      # Raises RuntimeError for any other tag.
      def load_string
        type = io.getbyte

        if type == 1
          start = load_varuint
          length = load_varuint
          input.byteslice(start, length).force_encoding(encoding)
        elsif type == 2
          load_embedded_string
        else
          raise "Unknown serialized string type: #{type}"
        end
      end

      # Read two varints and pack them into a single Integer: the first in the
      # bits above 32, the second in the low bits.
      def load_location
        high = load_varuint
        low = load_varuint
        (high << 32) | low
      end

      # Read two varints and build a Location on the loader's source from them.
      def load_location_object
        first = load_varuint
        second = load_varuint
        Location.new(source, first, second)
      end

      # Read a presence byte; when it is zero return nil, otherwise read and
      # return a packed location.
      def load_optional_location
        return if io.getbyte == 0

        load_location
      end

      # Read a presence byte; when it is zero return nil, otherwise read and
      # return a Location object.
      def load_optional_location_object
        return if io.getbyte == 0

        load_location_object
      end

      # Return the constant at +index+ in the constant pool as a Symbol,
      # decoding it on first use and caching it afterwards.
      #
      # Each pool entry is 8 bytes at constant_pool_offset + index * 8: a
      # 32-bit start and a 32-bit length. When the top bit of start is clear
      # the bytes come from the source text; when it is set, the remaining 31
      # bits are an offset into the serialized data itself.
      def load_constant(index)
        cached = constant_pool[index]
        return cached if cached

        entry_offset = constant_pool_offset + index * 8
        start = @serialized.unpack1("L", offset: entry_offset)
        length = @serialized.unpack1("L", offset: entry_offset + 4)

        bytes =
          if start.anybits?(1 << 31)
            @serialized.byteslice(start & ((1 << 31) - 1), length)
          else
            input.byteslice(start, length)
          end

        constant_pool[index] = bytes.force_encoding(@encoding).to_sym
      end

      # Read a 1-based constant id and return the corresponding constant.
      def load_required_constant
        id = load_varuint
        load_constant(id - 1)
      end

      # Read a constant id where 0 means "absent" (nil is returned); any other
      # value is a 1-based id that is resolved through load_constant.
      def load_optional_constant
        id = load_varuint
        return if id == 0

        load_constant(id - 1)
      end

      # Read one byte and translate it into an error level symbol:
      # 0 => :syntax, 1 => :argument, 2 => :load.
      #
      # Raises RuntimeError for any other value.
      def load_error_level
        level = io.getbyte

        { 0 => :syntax, 1 => :argument, 2 => :load }.fetch(level) do
          raise "Unknown level: #{level}"
        end
      end

      # Read one byte and translate it into a warning level symbol:
      # 0 => :default, 1 => :verbose.
      #
      # Raises RuntimeError for any other value.
      def load_warning_level
        level = io.getbyte
        return :default if level == 0
        return :verbose if level == 1

        raise "Unknown level: #{level}"
      end

      if RUBY_ENGINE == "ruby"
        # CRuby variant: read one node and, recursively, its children.
        #
        # The stream holds a type byte, a varint node id and a packed location,
        # followed by the node's own fields. The template emits one `when`
        # branch per node class, numbered from 1 in template order. Nodes for
        # which needs_serialized_length? is true carry an extra uint32 that is
        # read and discarded. The fourth constructor argument is a varint read
        # before the fields (presumably the node's flags; confirm against the
        # node definitions).
        #
        # A type byte with no matching branch makes the case expression, and
        # therefore this method, evaluate to nil.
        def load_node
          type = io.getbyte
          node_id = load_varuint
          location = load_location

          case type
          <%- nodes.each_with_index do |node, index| -%>
          when <%= index + 1 %> then
            <%- if node.needs_serialized_length? -%>
            load_uint32
            <%- end -%>
            <%= node.name %>.new(<%= ["source", "node_id", "location", "load_varuint", *node.fields.map { |field|
              case field
              when Prism::Template::NodeField then "load_node"
              when Prism::Template::OptionalNodeField then "load_optional_node"
              when Prism::Template::StringField then "load_string"
              when Prism::Template::NodeListField then "Array.new(load_varuint) { load_node }"
              when Prism::Template::ConstantField then "load_required_constant"
              when Prism::Template::OptionalConstantField then "load_optional_constant"
              when Prism::Template::ConstantListField then "Array.new(load_varuint) { load_required_constant }"
              when Prism::Template::LocationField then "load_location"
              when Prism::Template::OptionalLocationField then "load_optional_location"
              when Prism::Template::UInt8Field then "io.getbyte"
              when Prism::Template::UInt32Field then "load_varuint"
              when Prism::Template::IntegerField then "load_integer"
              when Prism::Template::DoubleField then "load_double"
              else raise
              end
            }].join(", ") -%>)
            <%- end -%>
          end
        end
      else
        # Non-CRuby variant: read the type byte and run the lambda stored at
        # that index in @load_node_lambdas, which reads the rest of the node.
        def load_node
          reader = @load_node_lambdas[io.getbyte]
          reader.call
        end

        # Build @load_node_lambdas, the dispatch table used by load_node on
        # engines other than CRuby.
        #
        # Index 0 is nil, so each lambda's index equals the 1-based position of
        # its node class in the template. Each lambda reads the node id and
        # packed location, discards a uint32 for nodes where
        # needs_serialized_length? is true, then reads the remaining fields and
        # constructs the node. The fourth constructor argument is a varint read
        # before the fields (presumably the node's flags; confirm against the
        # node definitions).
        def define_load_node_lambdas
          @load_node_lambdas = [
            nil,
            <%- nodes.each do |node| -%>
            -> {
              node_id = load_varuint
              location = load_location
              <%- if node.needs_serialized_length? -%>
              load_uint32
              <%- end -%>
              <%= node.name %>.new(<%= ["source", "node_id", "location", "load_varuint", *node.fields.map { |field|
                case field
                when Prism::Template::NodeField then "load_node"
                when Prism::Template::OptionalNodeField then "load_optional_node"
                when Prism::Template::StringField then "load_string"
                when Prism::Template::NodeListField then "Array.new(load_varuint) { load_node }"
                when Prism::Template::ConstantField then "load_required_constant"
                when Prism::Template::OptionalConstantField then "load_optional_constant"
                when Prism::Template::ConstantListField then "Array.new(load_varuint) { load_required_constant }"
                when Prism::Template::LocationField then "load_location"
                when Prism::Template::OptionalLocationField then "load_optional_location"
                when Prism::Template::UInt8Field then "io.getbyte"
                when Prism::Template::UInt32Field then "load_varuint"
                when Prism::Template::IntegerField then "load_integer"
                when Prism::Template::DoubleField then "load_double"
                else raise
                end
              }].join(", ") -%>)
            },
            <%- end -%>
          ]
        end
      end
    end

    # The token types that can be indexed by their enum values.
    TOKEN_TYPES = [
      nil,
      <%- tokens.each do |token| -%>
      <%= token.name.to_sym.inspect %>,
      <%- end -%>
    ]
  end
end
